import time
import joblib
from sklearn import metrics
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler, label_binarize, OneHotEncoder
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
from mdp_random import data_handle
from sklearn.metrics import roc_auc_score as roc
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import confusion_matrix
import matplotlib as mpl

from mdp_all import mdp_data


def plot_roc(labels, predict_prob, auc, macro, macro_recall, weighted):
    """Show a ROC curve and a one-row table of evaluation metrics.

    The figure has two stacked panels: the ROC curve computed from
    ``labels`` and ``predict_prob`` on top, and a table with the four
    supplied metric values underneath. The computed area under the curve
    is also printed to stdout. The figure is displayed with ``plt.show()``.

    Args:
        labels: True labels, forwarded to ``metrics.roc_curve``.
        predict_prob: Scores/predictions, forwarded to ``metrics.roc_curve``.
        auc: Value displayed in the first table column (labelled accuracy);
            note that it is not the area under the curve, which is
            recomputed here.
        macro: Value displayed in the precision column.
        macro_recall: Value displayed in the recall column.
        weighted: Value displayed in the F1 column.

    Returns:
        None.
    """
    # One figure with two rows: ROC curve above, metric table below.
    _fig, (roc_ax, table_ax) = plt.subplots(
        ncols=1, nrows=2, figsize=(6.5, 6.5), dpi=100
    )

    # ---- Upper panel: ROC curve ----
    # roc_curve also returns the thresholds, which are not needed here.
    fpr, tpr, _thresholds = metrics.roc_curve(labels, predict_prob)
    area_under_curve = metrics.auc(fpr, tpr)
    print('AUC=' + str(area_under_curve))

    roc_ax.set_title('ROC')
    roc_ax.plot(fpr, tpr, 'b', label='AUC = %0.4f' % area_under_curve)
    roc_ax.legend(loc='lower right')
    roc_ax.plot([0, 1], [0, 1], 'r--')  # chance-level diagonal
    roc_ax.set_ylabel('TPR（真阳性率）')
    roc_ax.set_xlabel('FPR（伪阳性率）')

    # ---- Lower panel: metric table (axes frame hidden) ----
    table_ax.axis('off')
    table_ax.set_title('模型评价指标', y=-0.1)

    # Use a CJK-capable font and ASCII minus signs so the Chinese labels
    # and negative numbers render correctly.
    mpl.rcParams["font.sans-serif"] = ["SimHei"]
    mpl.rcParams["axes.unicode_minus"] = False

    headers = ['准确率', '精确率', '召回率', 'f1值']
    palette = ['red', 'pink', 'green', 'gold']
    metric_table = table_ax.table(
        cellText=[[auc, macro, macro_recall, weighted]],
        colWidths=[0.18] * len(headers),
        rowLabels=['实际'],
        colLabels=headers,
        rowColours=palette,
        colColours=palette,
        loc="center",
    )
    metric_table.set_fontsize(14)
    metric_table.scale(1.5, 1.5)

    plt.show()
    # plt.savefig('figures/PC5.png')  # optionally save the ROC figure to disk
def LDA_algorithm(filename, num=0):
    """Train and evaluate an LDA + decision-tree classifier on one dataset.

    Steps: load the data with ``data_handle``, label-encode the targets,
    split 75/25 with a fixed seed, standardise the features (scaler fitted
    on the training split only), project onto a single linear discriminant,
    fit an entropy-based decision tree, pickle the tree to
    ``files/lda.pkl`` and print the test-set metrics.

    Note that only the decision tree is pickled; the fitted scaler and LDA
    projection are not saved.

    Args:
        filename: Dataset identifier passed straight to ``data_handle``.
        num: ``0`` (default) displays the ROC curve and metric table via
            ``plot_roc``; any other value passes the results to
            ``mdp_data.add_data`` instead.

    Returns:
        None.
    """
    import os  # local import: only needed to prepare the output directory

    datasets, labels, _count = data_handle(filename)  # load and preprocess the dataset
    print("Running for Method: LDA")
    print("len of X", len(datasets))
    print("no of column", len(datasets[0]))

    y = LabelEncoder().fit_transform(labels)

    # Hold out 25% of the samples for testing (fixed seed for repeatability).
    X_train, X_test, y_train, y_test = train_test_split(
        datasets, y, test_size=0.25, random_state=0
    )

    # Feature scaling: fit on the training split only to avoid leakage.
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    # Dimensionality reduction to a single discriminant component.
    lda = LDA(n_components=1)
    X_train = lda.fit_transform(X_train, y_train)
    X_test = lda.transform(X_test)

    classifier = DecisionTreeClassifier(criterion='entropy', random_state=0)
    classifier.fit(X_train, y_train)

    # Make sure the target directory exists so the dump cannot fail on a
    # fresh checkout.
    model_path = 'files/lda.pkl'
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    joblib.dump(classifier, model_path)

    # Predict the test-set labels.
    y_pred = classifier.predict(X_test)

    accuracy = metrics.accuracy_score(y_test, y_pred)
    macro = metrics.precision_score(y_test, y_pred, average='macro')
    micro = metrics.precision_score(y_test, y_pred, average='micro')
    macro_recall = metrics.recall_score(y_test, y_pred, average='macro')
    weighted = metrics.f1_score(y_test, y_pred, average='weighted')
    print('准确率:', accuracy)  # accuracy
    print('宏平均精确率:', macro)  # macro-averaged precision
    print('微平均精确率:', micro)  # micro-averaged precision
    print('宏平均召回率:', macro_recall)  # macro-averaged recall
    print('平均F1-score:', weighted)  # weighted F1 score
    print('混淆矩阵输出:\n', metrics.confusion_matrix(y_test, y_pred))  # confusion matrix
    print('分类报告:', metrics.classification_report(y_test, y_pred))  # classification report

    if num == 0:
        # A ROC curve needs a continuous score; hard 0/1 predictions reduce
        # it to a single operating point. Use the positive-class probability
        # when the problem is binary, otherwise keep the predicted labels.
        class_probs = classifier.predict_proba(X_test)
        roc_scores = class_probs[:, 1] if class_probs.shape[1] == 2 else y_pred
        plot_roc(y_test, roc_scores, accuracy, macro, macro_recall, weighted)
    else:
        mdp_data.add_data(mdp_data, y_test, y_pred, accuracy, macro, macro_recall, weighted)


if __name__ == '__main__':
    # Script entry point: run the LDA pipeline on the default dataset and
    # show the ROC plot (num defaults to 0).
    LDA_algorithm(filename='MDP/ClassLevel10000+.csv')